# Copyright 2018- The Pixie Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import argparse
import ast
import os
from copy import deepcopy
from pathlib import Path

from presidio_evaluator.data_generator.faker_extensions.data_objects import \
    FakerSpansResult
from presidio_evaluator.data_objects import InputSample
from presidio_evaluator.evaluation import Evaluator
from presidio_evaluator.models import FlairTrainer
from presidio_evaluator.validation import save_to_json, split_dataset


# Mapping from raw Privy/faker labels to the entity types used as targets for
# the flair sequence labelling model. The value "O" marks a non-entity, i.e. a
# label that flair will not be trained to predict.
#
# The table is written as ordered (entity_type, labels) groups and expanded
# into a flat dict; groups are expanded in the order they are listed.
PRIVY_ENTITIES = {
    faker_label: entity_type
    for entity_type, faker_labels in (
        # person names
        ("PER", (
            "person",
            "name_male",
            "name_female",
            "first_name",
            "first_name_male",
            "first_name_female",
            "first_name_nonbinary",
            "last_name",
            "last_name_male",
            "last_name_female",
        )),
        # addresses and airports
        ("LOC", (
            "address",
            "street_address",
            "secondary_address",
            "zipcode",
            "building_number",
            "street_name",
            "airport_name",
            "airport_iata",
            "airport_icao",
        )),
        # countries, states and cities
        ("LOC", (
            "country",
            "country_code",
            "state",
            "state_abbr",
            "city",
        )),
        # coordinates
        ("LOC", (
            "coordinate",
            "longitude",
            "latitude",
        )),
        # nationality and religion
        ("NRP", (
            "nationality",
            "nation_woman",
            "nation_man",
            "nation_plural",
            "religion",
        )),
        # dates and times
        ("DATE_TIME", (
            "date",
            "date_time",
            "date_of_birth",
            "day_of_week",
            "year",
            "month",
        )),
        # web addresses
        ("O", (
            "url",
            "domain_name",
        )),
        # credit cards: the number is a non-entity, the expiry is a date
        ("O", (
            "credit_card_number",
        )),
        ("DATE_TIME", (
            "credit_card_expire",
        )),
        # account numbers, identity documents and contact details
        ("O", (
            "iban",
            "bban",
            "phone_number",
            "ssn",
            "passport",
            "driver_license",
            "ip_address",
            "itin",
            "email",
        )),
        # organizations
        ("ORG", (
            "organization",
            "company",
            "airline",
        )),
        # job, name prefix and gender
        ("O", (
            "job",
            "prefix",
            "prefix_male",
            "prefix_female",
            "gender",
        )),
        # miscellaneous identifiers and attributes
        ("O", (
            "imei",
            "password",
            "license_plate",
            "mac_address",
            "age",
        )),
        # financial codes
        ("O", (
            "currency_code",
            "aba",
            "swift",
        )),
        # generic value types
        ("O", (
            "string",
            "boolean",
            "color",
            "random_number",
            "sha1",
        )),
    )
    for faker_label in faker_labels
}


class FlairDataset:
    """Privy dataset converted to presidio samples and split for flair training.

    On construction the dataset is loaded with
    ``FakerSpansResult.load_privy_dataset``, converted with
    ``InputSample.convert_faker_spans`` and its labels are aligned with
    ``PRIVY_ENTITIES``. The result is kept in ``self.dataset``.
    """

    # Absolute tolerance for the "ratios add up to 1" check, so that binary
    # floating-point rounding (e.g. 0.6 + 0.3 + 0.1) does not cause a rejection.
    _RATIO_SUM_TOLERANCE = 1e-9

    def __init__(self, dataset_path, out_folder):
        """Load and convert the dataset.

        Args:
            dataset_path: Path passed to ``FakerSpansResult.load_privy_dataset``.
            out_folder: Folder in which the split JSON files are saved.
        """
        fake_records = FakerSpansResult.load_privy_dataset(dataset_path)
        samples = InputSample.convert_faker_spans(fake_records)
        # translate raw faker labels to privy entity types, marking unsupported entities as
        # non-entities 'O' (if not filtered out previously)
        self.dataset = Evaluator.align_entity_types(
            deepcopy(samples),
            entities_mapping=PRIVY_ENTITIES,
            allow_missing_mappings=True,
        )
        self.output_folder = Path(out_folder)

    @staticmethod
    def _parse_ratios(ratios):
        """Turn ``ratios`` into a list of floats and check that it sums to 1.

        Args:
            ratios: Either a list-literal string such as ``"[0.7, 0.2, 0.1]"``
                or a sequence of numbers / numeric tokens such as the list
                argparse produces for an option declared with ``nargs='+'``.

        Returns:
            The ratios as a list of floats.

        Raises:
            argparse.ArgumentTypeError: If the value is not a non-empty list
                of numbers or the numbers do not add up to 1.
        """
        if not isinstance(ratios, str):
            # A sequence of tokens: glue it back into one string so that every
            # accepted spelling goes through the same parsing path below.
            ratios = " ".join(str(token) for token in ratios)
        text = ratios.strip()
        if not text.startswith(("[", "(")):
            # Bare numbers ("0.7 0.2 0.1" or "0.7,0.2,0.1"): wrap in a list literal.
            text = "[" + ", ".join(text.replace(",", " ").split()) + "]"

        parsed = ast.literal_eval(text)
        if not isinstance(parsed, (list, tuple)) or not parsed:
            raise argparse.ArgumentTypeError(
                f"Ratios {ratios} must be a non-empty list of numbers.")

        values = [float(value) for value in parsed]
        if abs(sum(values) - 1.0) > FlairDataset._RATIO_SUM_TOLERANCE:
            raise argparse.ArgumentTypeError(
                f"Ratios {ratios} must add up to 1.")
        return values

    def train_test_val_split(self, ratios):
        """Split converted dataset into train, test, and validation sets.

        The three parts are saved as ``train.json``, ``test.json`` and
        ``validation.json`` inside ``self.output_folder``.

        Args:
            ratios: Split ratios in train, test, validation order; see
                ``_parse_ratios`` for the accepted formats.

        Returns:
            The tuple ``(train, test, validation)`` produced by ``split_dataset``.

        Raises:
            argparse.ArgumentTypeError: If the ratios are not a list of
                numbers adding up to 1.
        """
        split_ratios = self._parse_ratios(ratios)
        train, test, validation = split_dataset(self.dataset, split_ratios)
        print(f"Train, test, validation sizes: {len(train)}, {len(test)}, {len(validation)}")
        print("Saving train, test, validation to json")
        # Make sure the destination folder exists before writing into it.
        self.output_folder.mkdir(parents=True, exist_ok=True)
        for file_stem, samples in (("train", train), ("test", test), ("validation", validation)):
            save_to_json(samples, self.output_folder / f"{file_stem}.json")
        return train, test, validation


def _ratios_as_list_literal(ratios):
    """Return ``ratios`` as one list-literal string such as "[0.7, 0.2, 0.1]".

    ``--ratios`` is declared with ``nargs='+'``, so values given on the command
    line arrive as a list of tokens while the default stays a plain string.
    The tokens are joined back into a single string so that downstream code
    always receives the same format as the default.
    """
    if isinstance(ratios, str):
        return ratios
    joined = " ".join(ratios).strip()
    if joined.startswith(("[", "(")):
        # Already a bracketed literal, possibly split into several tokens by the shell.
        return joined
    # Bare numbers, separated by spaces and/or commas.
    return "[" + ", ".join(joined.replace(",", " ").split()) + "]"


def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace with the attributes ``input``, ``ratios`` (always a
        list-literal string), ``out_folder``, ``transformer``, ``rnn`` and
        ``rnn_fast``.
    """
    parser = argparse.ArgumentParser(
        description="""Train Flair Named Entity Recognition (NER) model for token-wise PII classification using
                    utility functions adapted from presidio-research""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--input",
        "-i",
        required=True,
        help="""Absolute path to full input data. Must be in FakerSpansResult format generated by Privy.""",
    )

    parser.add_argument(
        "--ratios",
        "-r",
        required=False,
        default="[0.7, 0.2, 0.1]",
        nargs='+',
        help="""Ratios to split data into train, test, and validation.
        Accepts space separated numbers or a single quoted list.""",
    )

    parser.add_argument(
        "--out_folder",
        "-o",
        required=False,
        default=os.path.join(os.path.dirname(__file__), os.pardir),
        help="""Absolute path to where trained flair model(s) will be stored.
        By default, saves to bazel cache for this runtime.""",
    )

    # NOTE(review): the help text mentions roberta base, while main() passes
    # embeddings="distilbert-base-cased" -- confirm which one is intended.
    parser.add_argument(
        "--transformer",
        "-t",
        required=False,
        action="store_true",
        help="""Train transformer model with roberta base"""
    )

    parser.add_argument(
        "--rnn",
        "-rnn",
        required=False,
        action="store_true",
        help="""Train rnn model"""
    )
    parser.add_argument(
        "--rnn_fast",
        "-rnn_fast",
        required=False,
        action="store_true",
        help="""Train fast rnn model"""
    )

    args = parser.parse_args()
    # Hand the ratios on in one canonical string format, whether they came
    # from the default or from one or more command-line tokens.
    args.ratios = _ratios_as_list_literal(args.ratios)
    return args


def main(args):
    """Split the dataset and train the flair models selected on the command line.

    Args:
        args: Namespace providing ``input``, ``out_folder``, ``ratios``,
            ``transformer``, ``rnn`` and ``rnn_fast``.
    """
    # Load the dataset and split it into its three parts.
    privy_dataset = FlairDataset(args.input, args.out_folder)
    train_set, test_set, val_set = privy_dataset.train_test_val_split(args.ratios)

    # Build the flair corpus from the three parts and read it back.
    flair_trainer = FlairTrainer()
    flair_trainer.create_flair_corpus(train_set, test_set, val_set, to_bio=True)
    corpus = flair_trainer.read_corpus("./")

    # The transformer model is trained when asked for explicitly, and also as
    # the fallback when neither rnn variant was requested.
    rnn_requested = args.rnn or args.rnn_fast
    if args.transformer or not rnn_requested:
        flair_trainer.train_with_transformers(
            corpus,
            mini_batch_size=32,
            embeddings="distilbert-base-cased",
            max_epochs=20,
        )

    # rnn model
    if args.rnn:
        flair_trainer.train_with_flair_embeddings(corpus)

    # fast rnn model
    if args.rnn_fast:
        flair_trainer.train_with_flair_embeddings(corpus, fast=True)


# Script entry point: parse the command line, then run the training pipeline.
if __name__ == "__main__":
    main(parse_args())
